<script src="../dist/browser/tfjs-onnx.js"></script>
<script src="utils.js"></script>
<script src="draw.js"></script>

<style>
  #draw {
    border: 1px dotted black;
    cursor: crosshair;
  }
</style>

<div>
  <canvas width="240" height="240" id="draw"></canvas>
</div>
<div>
  <button onclick="draw.clean()">Reset</button>
</div>
<div>
  <canvas id="img"></canvas>
</div>
<div id="result"></div>

<script>
  // Demo wiring: lets the user draw on the #draw canvas and, on every
  // mouseup, runs the ONNX MNIST model on the drawing and shows the result.
  // Relies on globals provided by the scripts included above (tf,
  // makeCanvasDrawable, getImageData, displayImage, getTopKClasses,
  // displayLabel).

  // NOTE(review): imgUrl is not used in this script; kept because it is a
  // global binding that the included helper scripts could reference.
  const imgUrl = 'img/mnist_5.png';
  const modelUrl = 'models/mnist/mnist.onnx';
  const drawElem = document.getElementById('draw');

  const draw = makeCanvasDrawable(drawElem);
  draw.start();

  (async function(){
    // Load the model
    const model = await tf.onnx.loadModel(modelUrl);
    console.debug(model);

    drawElem.addEventListener('mouseup', async () => {
      let predictions;

      try {
        // Load the pixels and scale them to the model input size
        // NOTE(review): if getImageData returns a tensor it should also be
        // disposed once displayImage no longer needs it - confirm in utils.js.
        const data = getImageData(drawElem, {
          shape: [28, 28], gray: true
        });

        // Visualize image
        displayImage(data, 'img');

        // Perform the forward pass and return the prediction.
        // tf.tidy frees the intermediate tensors (logits) but NOT the
        // returned tensor, which is released in the finally block below.
        predictions = tf.tidy(() => {
          const logits = model.predict(data);
          return logits.as1D().softmax();
        });

        // Read the values asynchronously so the UI thread is not blocked
        console.log(await predictions.data());

        // Output the result
        const topK = await getTopKClasses(predictions, 5);

        const probs = topK[0];
        const classNames = Array.from(topK[1]);

        displayLabel(probs, classNames, 'result');
      } catch (err) {
        // Surface the failure instead of leaving an unhandled rejection
        console.error('Prediction failed:', err);
      } finally {
        // Release the tensor returned from tf.tidy; otherwise one tensor
        // leaks on every mouseup.
        if (predictions) {
          predictions.dispose();
        }
      }

    }, false );
  })().catch((err) => {
    // Model loading (or listener setup) failed
    console.error('Failed to initialize the MNIST demo:', err);
  });

</script>
